import os
import sys
import json
import math
import torch
import argparse
import numpy as np
import open3d as o3d
from io import BytesIO
from pytorch3d.io import IO
import matplotlib.pyplot as plt
from plyfile import PlyData, PlyElement
from PIL import Image, ImageDraw, ImageFont
from pytorch3d.structures import Pointclouds, Meshes
from pytorch3d.renderer import (
    PointLights,
    AmbientLights,
    RasterizationSettings,
    SoftPhongShader,
    MeshRasterizer,
    MeshRenderer,
    PointsRenderer,
    PointsRasterizationSettings,
    PointsRasterizer,
    AlphaCompositor,
    FoVPerspectiveCameras,
    PerspectiveCameras,
    look_at_view_transform,
)


# Command-line interface. Every path option defaults to the empty string.
parser = argparse.ArgumentParser(description="Process and align meshes to axes.")
parser.add_argument(
    "--scannet_file",
    type=str,
    default="",
    help="Path to the ScanNet scene file list.",
)
parser.add_argument(
    "--scan_dir",
    type=str,
    default="",
    help="Directory containing ScanNet scans.",
)
parser.add_argument(
    "--axis_alignment_info_file",
    type=str,
    default="",
    help="Path to the axis alignment info JSON file.",
)
# The help texts of --output_dir and --pcd_dir used to be copies of the
# --scan_dir description, which made `--help` misleading.
parser.add_argument(
    "--output_dir",
    type=str,
    default="",
    help="Directory where output files are written.",
)
# NOTE(review): --pcd_dir is not read by the code visible in this file; the
# help text below is inferred from the option name -- confirm.
parser.add_argument(
    "--pcd_dir",
    type=str,
    default="",
    help="Directory containing point cloud files.",
)
parser.add_argument(
    "--bbox_dir",
    type=str,
    default="",
    help="Directory containing bbox file.",
)
parser.add_argument(
    "--gt_bbox_dir",
    type=str,
    default="",
    help="Directory containing gt bbox file.",
)
parser.add_argument("--image_size", type=int, default=1024)
parser.add_argument("--dataset", type=str, default="scannet")



def read_file_to_list(file_path):
    """Return the lines of the text file at ``file_path`` in sorted order."""
    with open(file_path, "r") as handle:
        lines = handle.read().splitlines()
    lines.sort()
    return lines


def read_dict(file_path):
    """Parse the JSON document stored at ``file_path`` and return the result."""
    with open(file_path) as fin:
        raw_text = fin.read()
    return json.loads(raw_text)


def load_mesh_data(scan_data_file):
    """
    Read vertex positions, vertex colors and (optionally) faces from a PLY file.

    The first element of the file is read as the vertex table (fields
    ``x``, ``y``, ``z``, ``red``, ``green``, ``blue``); the second element,
    when present, supplies the ``vertex_indices`` of the faces.

    Parameters:
    - scan_data_file: Path to the PLY file.

    Returns:
    - vertices: A numpy array of shape (N, 3) with the x/y/z coordinates.
    - colors: A numpy array of shape (N, 3) with the red/green/blue values.
    - faces: The ``vertex_indices`` column of the second element as a numpy
      array, or None when the file has only one element.
    """
    ply = PlyData.read(scan_data_file)
    vertex_table = ply.elements[0].data

    # Column-wise stacking turns the six per-field arrays into two (N, 3) arrays
    xyz_columns = [np.asarray(vertex_table[axis]) for axis in ("x", "y", "z")]
    rgb_columns = [
        np.asarray(vertex_table[channel]) for channel in ("red", "green", "blue")
    ]
    vertices = np.stack(xyz_columns, axis=1)
    colors = np.stack(rgb_columns, axis=1)

    # Faces are optional: a vertex-only file has a single element
    if len(ply.elements) > 1:
        faces = np.asarray(ply.elements[1].data["vertex_indices"])
    else:
        faces = None

    return vertices, colors, faces


def save_mesh(vertices, colors, faces, file_path):
    """
    Save the mesh vertices, colors, and faces to a text-format PLY file.

    Parameters:
    - vertices: A numpy array of shape (N, 3) containing the mesh vertices.
    - colors: A numpy array of shape (N, 3) containing the vertex colors
      (stored as unsigned 8-bit values).
    - faces: A numpy array of shape (M, 3) containing the mesh faces, or
      None to write a vertex-only file.
    - file_path: The path where the PLY file will be saved. Missing parent
      directories are created; a bare file name is written to the current
      working directory.
    """
    # Create a structured array for the vertices and colors
    vertex_data = np.array(
        [(v[0], v[1], v[2], c[0], c[1], c[2]) for v, c in zip(vertices, colors)],
        dtype=[
            ("x", "f4"),
            ("y", "f4"),
            ("z", "f4"),
            ("red", "u1"),
            ("green", "u1"),
            ("blue", "u1"),
        ],
    )

    # Create a PlyElement for the vertices
    elements = [PlyElement.describe(vertex_data, "vertex")]

    if faces is not None:
        # Create a structured array for the faces
        face_data = np.array(
            [(f,) for f in faces], dtype=[("vertex_indices", "i4", (3,))]
        )
        elements.append(PlyElement.describe(face_data, "face"))

    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Write the vertices and faces to a PLY file
    PlyData(elements, text=True).write(file_path)
    print(f"Saved mesh to {file_path}")


def load_json(file_path):
    """Open ``file_path`` for reading and return its decoded JSON content."""
    with open(file_path, "r") as file:
        parsed = json.load(file)
    return parsed
    

def load_bboxes(room, bbox_dir):
    """
    Load the bounding boxes (GT or predicted) of one room.

    Reads ``<bbox_dir>/<room>.json`` and returns a dict that maps each
    entry's ``bbox_id`` (converted to int) to the entry itself.
    """
    records = load_json(os.path.join(bbox_dir, f"{room}.json"))
    by_id = {}
    for record in records:
        by_id[int(record["bbox_id"])] = record
    return by_id


def process_aligned_mesh(scene, remove_top=True, filtered_output_dir=None):
    """
    Axis-align the mesh of one scene and save it, optionally without its top.

    Reads ``<args.scan_dir>/<scene>/<scene>_vh_clean_2.ply``, applies the 4x4
    matrix stored under ``scene`` in ``args.axis_alignment_info_file`` and
    writes ``<scene>.ply``. Relies on the module-level ``args`` namespace.

    Parameters:
    - scene: Scene identifier; used as sub-directory name, file prefix and
      key into the alignment dictionary.
    - remove_top: When False the complete aligned mesh is written to
      ``args.output_dir``. When True every vertex whose aligned z is within
      0.4 of the highest vertex is dropped, together with each face that
      references a dropped vertex.
    - filtered_output_dir: Directory for the filtered mesh. Defaults to
      ``args.output_dir`` (the previous hard-coded empty directory made
      ``save_mesh`` fail on ``os.makedirs("")``).
    """
    scan_data_file = os.path.join(args.scan_dir, scene, scene + "_vh_clean_2.ply")
    print(f"Processing {scan_data_file}")

    # Load mesh data
    vertices, colors, faces = load_mesh_data(scan_data_file)

    # Get the axis alignment matrix
    scans_axis_alignment_matrices = read_dict(args.axis_alignment_info_file)
    alignment_matrix = np.array(
        scans_axis_alignment_matrices[scene], dtype=np.float32
    ).reshape(4, 4)

    # Transform the vertices in homogeneous coordinates
    pts = np.ones((vertices.shape[0], 4), dtype=vertices.dtype)
    pts[:, 0:3] = vertices
    aligned_vertices = np.dot(pts, alignment_matrix.transpose())[:, :3]

    # Ensure no NaN values are introduced after transformation
    assert not np.isnan(aligned_vertices).any()
    print(f"Aligned mesh to axes of {scene}")

    if not remove_top:
        # Save the aligned mesh to the output directory
        output_path = os.path.join(args.output_dir, f"{scene}.ply")
        save_mesh(aligned_vertices, colors, faces, output_path)
        return

    threshold = 0.4  # height threshold
    z_max = np.max(aligned_vertices[:, 2])
    valid_mask = aligned_vertices[:, 2] < z_max - threshold

    filtered_verts = aligned_vertices[valid_mask]
    filtered_colors = colors[valid_mask]

    # A vertex-only scan has no faces to filter
    filtered_faces_array = None
    if faces is not None:
        # Old vertex index -> new index in the filtered array (-1 if dropped)
        index_mapping = np.full(vertices.shape[0], -1, dtype=int)
        index_mapping[valid_mask] = np.arange(np.count_nonzero(valid_mask))

        faces_array = np.array(
            [np.array(f, dtype=int) for f in faces], dtype=int
        ).reshape(-1, 3)
        # Keep a face only if all three of its vertices survived
        valid_faces_mask = valid_mask[faces_array].all(axis=1)
        filtered_faces_array = index_mapping[faces_array[valid_faces_mask]]

    # save file
    if filtered_output_dir is None:
        filtered_output_dir = args.output_dir
    output_path_filter = os.path.join(filtered_output_dir, f"{scene}.ply")
    save_mesh(filtered_verts, filtered_colors, filtered_faces_array, output_path_filter)


def create_point_cloud(scan_pc, device):
    """
    Wrap scan data in a single-cloud ``Pointclouds`` batch.

    Args:
        scan_pc (np.ndarray): 2-D array; the first three columns are used as
            point coordinates and the remaining columns as per-point features
            (colors).
        device (str): Device the resulting point cloud is moved to.

    Returns:
        Pointclouds: Batch holding one cloud, on ``device``.
    """
    xyz = scan_pc[:, :3]
    rgb = scan_pc[:, 3:]
    cloud = Pointclouds(
        points=[torch.tensor(xyz, dtype=torch.float32)],
        features=[torch.tensor(rgb, dtype=torch.float32)],
    )
    return cloud.to(device)


def setup_camera(
    point_cloud,
    center,
    image_size,
    camera_distance_factor=1.0,
    camera_lift=1.0,
    camera_shift=1.0,
    camera_dist=8.0,
    device="cuda",
    calibrate=True,
    zoom_in=False
):
    """
    Build five fixed-view cameras (top, down, up, left, right) for a scene.

    All cameras look at ``center`` and share the same intrinsics (focal
    length 2.0, principal point 0.0). The top camera's distance is chosen
    from the larger of the cloud's x/y extents and the vertical field of
    view; the four oblique cameras are tilted by 45 degrees.

    Args:
        point_cloud (Pointclouds): The point cloud to render; only its
            bounding box and device are used.
        center (array-like): 3-D point every camera looks at.
        image_size (int): Currently unused; kept for interface compatibility.
        camera_distance_factor (float): Currently unused.
        camera_lift (float): Currently unused.
        camera_shift (float): Currently unused.
        camera_dist (float): Currently unused.
        device (str): Device the cameras are created on.
        calibrate (bool): Accepted for backward compatibility and ignored.
            The former calibration code referenced undefined names (raising
            NameError with the default ``calibrate=True``) and its result was
            never part of the return value.
        zoom_in (bool): When True the oblique cameras use a distance of
            ``extent / 2 * sqrt(2)`` along their axis (x for left/right,
            y for up/down); otherwise they reuse the top camera's distance.

    Returns:
        tuple: ``(cameras_top, cameras_down, cameras_up, cameras_left,
        cameras_right)``, each a ``PerspectiveCameras``.
    """
    # Extent of the point cloud along each axis
    points = point_cloud.points_padded()
    min_bounds = points.min(dim=1)[0]
    max_bounds = points.max(dim=1)[0]
    bound_x, bound_y, _ = (max_bounds - min_bounds)[0]
    max_bound = max(bound_x, bound_y)

    look_at = torch.tensor(center, dtype=torch.float32).unsqueeze(0)

    # Shared camera intrinsics, shape (1, 2) each
    focal_length = torch.tensor([[2.0, 2.0]]).to(point_cloud.device)
    principal_point = torch.tensor([[0.0, 0.0]]).to(point_cloud.device)

    # Distance at which the larger horizontal extent fills the vertical FOV
    fov_y_rad = math.radians(focal_length_to_fov(focal_length[:, 1]))
    camera_dist_top = (max_bound / 2) / math.tan(fov_y_rad / 2)

    if zoom_in:
        camera_dist_left_right = (bound_x / 2) * math.sqrt(2)  # 45 degrees
        camera_dist_up_down = (bound_y / 2) * math.sqrt(2)  # 45 degrees
    else:
        camera_dist_left_right = camera_dist_top
        camera_dist_up_down = camera_dist_top

    # (dist, elev, azim, up) per view, listed in return order
    view_specs = (
        (camera_dist_top, 0, 0, (0, 1, 0)),  # top
        (camera_dist_up_down, -45, 0, (0, 1, 0)),  # down
        (camera_dist_up_down, 45, 0, (0, -1, 0)),  # up
        (camera_dist_left_right, 0, -45, (1, 0, 0)),  # left
        (camera_dist_left_right, 0, 45, (-1, 0, 0)),  # right
    )

    cameras = []
    for dist, elev, azim, up in view_specs:
        R, T = look_at_view_transform(
            dist=dist,
            elev=elev,
            azim=azim,
            at=look_at,
            up=(up,),
        )
        cameras.append(
            PerspectiveCameras(
                device=device,
                R=R,
                T=T,
                focal_length=focal_length,
                principal_point=principal_point,
            )
        )

    return tuple(cameras)


def focal_length_to_fov(focal_length):
    """
    Convert a focal length tensor to a field of view in degrees.

    Computes ``2 * atan(1 / focal_length)`` (for the pytorch3d NDC coordinate
    system) and converts the angle from radians to degrees.
    """
    half_angle = torch.atan(1 / focal_length)
    full_angle = 2 * half_angle
    return torch.rad2deg(full_angle)


# def fov_to_focal_length(fov):
#     return 1 / (2 * torch.tan(torch.deg2rad(fov / 2)))


def render_point_cloud(point_cloud, cameras, image_size, device="cuda"):
    """
    Render a point cloud on a white background.

    Args:
        point_cloud (Pointclouds): The point cloud to render.
        cameras (PerspectiveCameras): The camera settings.
        image_size (int): The size of the output image.
        device (str): Not used by this function.

    Returns:
        PIL.Image.Image: The first rendered view, RGB channels scaled by 255
        and converted to uint8.
    """
    settings = PointsRasterizationSettings(
        image_size=image_size, radius=0.008, points_per_pixel=10
    )
    renderer = PointsRenderer(
        rasterizer=PointsRasterizer(cameras=cameras, raster_settings=settings),
        compositor=AlphaCompositor(background_color=(1.0, 1.0, 1.0)),
    )

    rendered = renderer(point_cloud)
    # Keep the first image of the batch and its first three channels
    rgb = rendered[0, ..., :3].cpu().numpy()
    return Image.fromarray((rgb * 255).astype(np.uint8))


def render_mesh(mesh, cameras, image_size, device="cuda"):
    """
    Render a mesh with a soft Phong shader, lit from the camera position.

    Args:
        mesh (Meshes): The mesh to render.
        cameras: Camera settings used for both rasterization and shading.
        image_size (int): The size of the output image.
        device (str): Device the light and shader are created on.

    Returns:
        PIL.Image.Image: The first rendered view, RGB channels scaled by 255
        and converted to uint8.
    """
    # The point light is placed at the camera centre. The former mesh-based
    # light position (centre + 1.5 * scale along z) was computed but never
    # used, so that dead computation has been removed.
    lights = PointLights(
        device=device,
        location=cameras.get_camera_center(),
        ambient_color=((0.6, 0.6, 0.6),),
        diffuse_color=((0.4, 0.4, 0.4),),
        specular_color=((0.0, 0.0, 0.0),),
    )

    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=1,
    )

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=SoftPhongShader(
            device=device,
            cameras=cameras,
            lights=lights
        )
    )

    images = renderer(mesh)
    # Keep the first image of the batch and its first three channels
    image_np = images[0, ..., :3].cpu().numpy()
    image_255 = (image_np * 255).astype(np.uint8)
    return Image.fromarray(image_255)


# def render_camera_image(cameras, image_size):
#     camera_matrix = cameras.get_world_to_view_transform().get_matrix()
#     # print('camera', camera_matrix.shape)

#     pts_view = cameras.get_world_to_view_transform().transform_points(pts_world)

#     image_np, rasterizer = render_point_cloud(point_cloud, cameras, image_size=image_size, device="cuda")
#     image_255 = (image_np * 255).astype(np.uint8)
#     color_image = Image.fromarray(image_255)
#     return color_image


def draw_ids(draw, bboxes, cameras, image_size, font):
    """
    Draw object IDs on the image.

    Each box's eight corners are projected to screen space; boxes with no
    corner inside the image are skipped, the others are labelled through
    ``draw_label``.

    Args:
        draw (ImageDraw): ImageDraw object.
        bboxes (list): List of bounding box dicts with keys ``bbox_id`` and
            ``bbox_3d`` (``x, y, z, w, l, h``: centre and full extents).
        cameras (PerspectiveCameras): Camera settings.
        image_size (int): Size of the output image.
        font (ImageFont): Font for drawing text.

    Returns:
        list: The ``bbox_id`` of every box whose label was drawn.
    """
    all_id = []
    for bbox in bboxes:
        bbox_id = bbox["bbox_id"]
        x, y, z, w, l, h = bbox["bbox_3d"]

        # Define the eight corners of the 3D bounding box
        corners = [
            [x - w / 2, y - l / 2, z - h / 2],
            [x - w / 2, y + l / 2, z - h / 2],
            [x + w / 2, y - l / 2, z - h / 2],
            [x + w / 2, y + l / 2, z - h / 2],
            [x - w / 2, y - l / 2, z + h / 2],
            [x - w / 2, y + l / 2, z + h / 2],
            [x + w / 2, y - l / 2, z + h / 2],
            [x + w / 2, y + l / 2, z + h / 2],
        ]

        # Project the 3D corners to the 2D image plane. The points are placed
        # on the cameras' own device instead of a hard-coded .cuda(), so CPU
        # cameras and non-default GPUs work as well.
        corners_2d = cameras.transform_points_screen(
            torch.tensor(corners, dtype=torch.float32, device=cameras.device),
            image_size=(image_size, image_size),
        )
        corners_2d = corners_2d[..., :2].cpu().numpy()

        # Skip drawing if all corners are out of image boundaries
        any_corner_visible = any(
            0 <= px < image_size and 0 <= py < image_size for px, py in corners_2d
        )
        if not any_corner_visible:
            continue

        # Draw the label and bbox_id
        if draw_label(draw, corners_2d, bbox_id, font, image_size):
            all_id.append(bbox_id)
    return all_id


def draw_ids_combine_overlap(draw, bboxes, cameras, image_size, font, thershold_2d=20):
    """
    Draw object IDs on the image. Combine IDs that are too close.

    Boxes whose projected centres are grouped together by
    ``process_2d_centers`` share a single label, drawn at the group's mean
    centre and showing the ID of the group's first box.

    Args:
        draw (ImageDraw): ImageDraw object.
        bboxes (list): List of bounding box dicts with keys ``bbox_id`` and
            ``bbox_3d`` (``x, y, z, w, l, h``: centre and full extents).
        cameras (PerspectiveCameras): Camera settings.
        image_size (int): Size of the output image.
        font (ImageFont): Font for drawing text.
        thershold_2d (float): Pixel distance below which two projected
            centres are merged (the spelling is part of the interface).

    Returns:
        list: The ``bbox_id`` drawn for each group whose label was drawn.
    """
    corners_2d_info = []
    all_id = []
    for bbox in bboxes:
        bbox_id = bbox["bbox_id"]
        x, y, z, w, l, h = bbox["bbox_3d"]

        # Define the eight corners of the 3D bounding box
        corners = [
            [x - w / 2, y - l / 2, z - h / 2],
            [x - w / 2, y + l / 2, z - h / 2],
            [x + w / 2, y - l / 2, z - h / 2],
            [x + w / 2, y + l / 2, z - h / 2],
            [x - w / 2, y - l / 2, z + h / 2],
            [x - w / 2, y + l / 2, z + h / 2],
            [x + w / 2, y - l / 2, z + h / 2],
            [x + w / 2, y + l / 2, z + h / 2],
        ]

        # Project the 3D corners to the 2D image plane. The points are placed
        # on the cameras' own device instead of a hard-coded .cuda(), so CPU
        # cameras and non-default GPUs work as well.
        corners_2d = cameras.transform_points_screen(
            torch.tensor(corners, dtype=torch.float32, device=cameras.device),
            image_size=(image_size, image_size),
        )
        corners_2d = corners_2d[..., :2].cpu().numpy()

        # Skip the box if all corners are out of image boundaries
        any_corner_visible = any(
            0 <= px < image_size and 0 <= py < image_size for px, py in corners_2d
        )
        if not any_corner_visible:
            continue

        corners_2d_info.append(
            {
                "bbox_id": bbox_id,
                "corners_2d": corners_2d,
                "centers_2d": np.mean(corners_2d, axis=0),
            }
        )

    combined_result = process_2d_centers(corners_2d_info, threshold_2d=thershold_2d)

    for res in combined_result.values():
        # Only the first ID of a merged group is displayed
        bbox_id = res["bbox_ids"][0]
        merged_center = np.array(res["merged_center_2d"]).reshape(1, 2)
        if draw_label(draw, merged_center, bbox_id, font, image_size):
            all_id.append(bbox_id)

    return all_id


class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``."""

    def __init__(self, n):
        # Every element starts as the root of its own set
        self.parent = [index for index in range(n)]

    def find(self, x):
        """Return the root of ``x``'s set, shortening the path on the way."""
        parent = self.parent
        while parent[x] != x:
            # Re-point x at its grandparent, then continue from there
            grandparent = parent[parent[x]]
            parent[x] = grandparent
            x = grandparent
        return x

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y``; ``y``'s root becomes the root."""
        root_y = self.find(y)
        root_x = self.find(x)
        self.parent[root_x] = root_y


def process_2d_centers(corners_2d_info, threshold_2d=30):
    """
    Group boxes whose 2-D centres lie closer than ``threshold_2d``.

    Two entries whose centres are less than ``threshold_2d`` apart are put
    in the same group; grouping is transitive (union-find).

    Args:
        corners_2d_info (list): Dicts with keys ``bbox_id`` and
            ``centers_2d`` (an ``(x, y)`` pair).
        threshold_2d (float): Distance below which two centres are merged.

    Returns:
        dict: Maps each group's root index to a dict with ``bbox_ids`` (IDs
        of the members, in input order) and ``merged_center_2d`` (the mean
        ``[x, y]`` of the members' centres).
    """
    count = len(corners_2d_info)
    sets = UnionFind(count)

    centers = []
    for info in corners_2d_info:
        centers.append(info['centers_2d'])
    ids = []
    for info in corners_2d_info:
        ids.append(info['bbox_id'])

    # Link every pair of centres that is closer than the threshold
    for first in range(count):
        for second in range(first + 1, count):
            xa, ya = centers[first]
            xb, yb = centers[second]
            if math.hypot(xa - xb, ya - yb) < threshold_2d:
                sets.union(first, second)

    # Collect member indices per root
    members_by_root = {}
    for index in range(count):
        root = sets.find(index)
        if root not in members_by_root:
            members_by_root[root] = []
        members_by_root[root].append(index)

    result = {}
    for root, members in members_by_root.items():
        member_centers = [centers[m] for m in members]
        mean_center = []
        for axis in range(2):
            axis_total = sum(c[axis] for c in member_centers)
            mean_center.append(axis_total / len(member_centers))
        result[root] = {
            'bbox_ids': [ids[m] for m in members],
            'merged_center_2d': mean_center,
        }

    return result


def draw_label(draw, corners_2d, bbox_id, font, image_size):
    """
    Draw ``bbox_id`` in red on a white patch at the mean of the given points.

    Args:
        draw (ImageDraw): ImageDraw object.
        corners_2d (array): 2D coordinates, one ``(x, y)`` row per point.
        bbox_id (int): Bounding box ID.
        font (ImageFont): Font for drawing text.
        image_size (int): Size of the output image.

    Returns:
        bool: True when the label was drawn, False when the (integer) centre
        falls outside ``[0, image_size)`` in either direction.
    """
    # The label is anchored at the mean of all supplied points
    mean_x, mean_y = np.mean(corners_2d, axis=0)
    center_x = int(mean_x)
    center_y = int(mean_y)

    inside_image = 0 <= center_x < image_size and 0 <= center_y < image_size
    if not inside_image:
        return False

    text = f"{bbox_id}"
    text_bbox = draw.textbbox((0, 0), text, font=font)
    half_width = (text_bbox[2] - text_bbox[0]) // 2
    half_height = (text_bbox[3] - text_bbox[1]) // 2

    # The white backing patch is shifted down by 4/5 px relative to the
    # text origin (hand-tuned offsets).
    patch = [
        center_x - half_width,
        center_y - half_height + 4,
        center_x + half_width,
        center_y + half_height + 5,
    ]
    draw.rectangle(patch, fill=(255, 255, 255))
    draw.text(
        (center_x - half_width, center_y - half_height),
        text,
        font=font,
        fill=(255, 0, 0),
    )
    return True


def load_scan_pc(scene, ply_path):
    """
    Load ``<ply_path>/<scene>.ply`` as a point cloud array.

    Returns:
        tuple: ``(scan_pc, center)`` where ``scan_pc`` is a float32 array
        holding the point coordinates followed by the point colors in each
        row, and ``center`` is the mean of the first three columns.
    """
    cloud = o3d.io.read_point_cloud(os.path.join(ply_path, f"{scene}.ply"))
    xyz = np.asarray(cloud.points)
    rgb = np.asarray(cloud.colors)

    # One row per point: coordinates first, then colors
    scan_pc = np.concatenate((xyz, rgb), axis=1).astype("float32")
    return scan_pc, np.mean(scan_pc[:, :3], axis=0)


def tag_view_image(output_image, annotation, save_path):
    """
    Write ``annotation`` in bold red text onto an image and save the result.

    The image is shown in a borderless matplotlib figure sized
    ``width / 100`` by ``height / 100`` inches, annotated at data
    coordinates (5, 5), exported to an in-memory JPEG, re-opened with PIL
    and saved to ``save_path``.

    Args:
        output_image (PIL.Image.Image): Image to annotate.
        annotation: Value rendered as the annotation text.
        save_path: Destination path passed to ``Image.save``.
    """
    img_width, img_height = output_image.size
    fig, ax = plt.subplots(figsize=(img_width / 100, img_height / 100))
    ax.imshow(output_image)
    ax.axis("off")

    # Font size scales with the image height
    ax.annotate(
        f"{annotation}",
        (5, 5),
        color="red",
        fontsize=0.05 * img_height,
        ha="left",
        va="top",
        fontweight="bold",
    )

    # Remove all margins so the axes fill the whole figure
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)

    buffer = BytesIO()
    fig.savefig(buffer, format="jpg", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buffer.seek(0)

    Image.open(buffer).save(save_path)


if __name__ == "__main__":
    # The parsed options are bound to the module-level name `args`, which
    # process_aligned_mesh reads directly.
    args = parser.parse_args()
    scans = read_file_to_list(args.scannet_file)
    print(f"Loaded {len(scans)} scenes")

    # Per-scene work is currently a placeholder: only the name is logged.
    for scene in scans:
        print(f"processing {scene}")
